package top.cardone.security.shiro.web.filter.authc.impl;

import com.google.gson.Gson;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.StringUtils;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.web.filter.authc.AuthenticatingFilter;
import org.springframework.core.task.AsyncListenableTaskExecutor;
import org.springframework.util.CollectionUtils;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.util.WebUtils;
import top.cardone.cache.Cache;
import top.cardone.context.ApplicationContextHolder;
import top.cardone.context.util.CodeExceptionUtils;
import top.cardone.context.util.TableUtils;
import top.cardone.core.CodeException;
import top.cardone.core.util.func.Func4;
import top.cardone.security.shiro.authc.impl.StatelessTokenImpl;

import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * @author yao hai tao
 * @date 16-2-2
 */
@Log4j2
public class StatelessAuthcFilterImpl extends AuthenticatingFilter {
    @Setter
    protected String credentialsParam = "token";

    @Setter
    protected String principalParam = "username";

    @Setter
    protected String lastOnlineTimesParam = "userLastOnlineTimes";

    @Setter
    protected String lastOnlineIpParam = "userLastOnlineIp";

    @Setter
    protected List<String> loginSuccessFuncBeanNames;

    @Setter
    private String cacheBeanName = "cardone.web.cache";

    @Setter
    private Long countUpperLimit = 1000L;
    @Setter
    private String taskExecutorBeanName = "slowTaskExecutor";

    private void log(ServletRequest request) {
        ApplicationContextHolder.getBean(AsyncListenableTaskExecutor.class, this.taskExecutorBeanName).submitListenable(() -> {
            if (countUpperLimit > 0L) {
                String rowKey = this.getClass().getName();

                String columnKey = ((HttpServletRequest) request).getServletPath();

                Long count = TableUtils.longAdderIncrementGetSum(rowKey, columnKey);

                if (count % countUpperLimit == 0) {
                    log.error(StringUtils.join("调用较频繁, rowKey: ", rowKey, ", columnKey: ", columnKey, ", count: ", count));
                }
            }
        });
    }

    @Override
    protected AuthenticationToken createToken(ServletRequest request, ServletResponse response) throws Exception {
        this.log(request);

        HttpServletRequest httpRequest = (HttpServletRequest) request;

        StatelessTokenImpl token = new StatelessTokenImpl();

        token.setCredentials(this.getParameter(httpRequest, credentialsParam));

        token.setPrincipal(this.getParameter(httpRequest, principalParam));

        return token;
    }

    private String getParameter(HttpServletRequest httpRequest, String name) {
        Cookie cookie = WebUtils.getCookie(httpRequest, name);

        String parameter = cookie == null ? null : cookie.getValue();

        if (StringUtils.isNotBlank(parameter)) {
            return parameter;
        }

        parameter = httpRequest.getHeader(name);

        if (StringUtils.isNotBlank(parameter)) {
            return parameter;
        }

        return httpRequest.getParameter(name);
    }

    private String getIpAddress(HttpServletRequest request) {
        if (request == null) {
            return null;
        }

        String ip = request.getHeader("x-forwarded-for");

        if (StringUtils.isBlank(ip) || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }

        if (StringUtils.isBlank(ip) || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }

        if (StringUtils.isBlank(ip) || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }

        if (StringUtils.isBlank(ip) || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }

        if (StringUtils.isBlank(ip) || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }

        if (StringUtils.isBlank(ip) || "unknown".equalsIgnoreCase(ip)) {
            return null;
        }

        return ip;
    }


    @Override
    protected boolean onLoginSuccess(AuthenticationToken token, Subject subject, ServletRequest request, ServletResponse response) {
        if (Objects.nonNull(token.getPrincipal())) {
            ApplicationContextHolder.getBean(Cache.class).put(this.lastOnlineTimesParam, token.getPrincipal(), System.currentTimeMillis());

            String ip = getIpAddress((HttpServletRequest) request);

            if (StringUtils.isNotBlank(ip)) {
                ApplicationContextHolder.getBean(Cache.class).put(this.lastOnlineIpParam, token.getPrincipal(), ip);
            }
        }

        if (!CollectionUtils.isEmpty(loginSuccessFuncBeanNames)) {
            for (String loginSuccessFuncBeanName : loginSuccessFuncBeanNames) {
                Func4<Boolean, AuthenticationToken, Subject, ServletRequest, ServletResponse> loginSuccessFunc = ApplicationContextHolder.getBean(Cache.class, this.cacheBeanName).get(
                        AuthenticatingFilter.class.getName(), 1,
                        loginSuccessFuncBeanName,
                        () -> (Func4<Boolean, AuthenticationToken, Subject, ServletRequest, ServletResponse>) ApplicationContextHolder.getBean(loginSuccessFuncBeanName)
                );

                if (loginSuccessFunc != null) {
                    if (!loginSuccessFunc.func(token, subject, request, response)) {
                        return false;
                    }
                }
            }
        }

        return true;
    }

    @Override
    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
        Subject subject = getSubject(request, response);

        if (subject.isAuthenticated()) {
            return true;
        }

        return executeLogin(request, response);
    }

    @Override
    protected boolean onLoginFailure(AuthenticationToken token, AuthenticationException e, ServletRequest request, ServletResponse response) {
        if (!StringUtils.startsWith(request.getContentType(), org.springframework.http.MediaType.APPLICATION_JSON_VALUE) && RequestMethod.GET.name().equalsIgnoreCase(((HttpServletRequest) request).getMethod())) {
            return super.onLoginFailure(token, e, request, response);
        }

        response.setCharacterEncoding(StandardCharsets.UTF_8.name());
        response.setContentType(org.springframework.http.MediaType.APPLICATION_JSON_VALUE);

        try (Writer out = response.getWriter()) {
            String requestURI = getPathWithinApplication(request);

            Map<String, String> errorInfo = CodeExceptionUtils.newMap(requestURI, new CodeException("login failure", null, "登录失败", e));

            String json = ApplicationContextHolder.getBean(Gson.class).toJson(errorInfo);

            errorInfo.clear();

            out.write(json);
        } catch (java.io.IOException ex) {
            log.error(ex.getMessage(), ex);
        }

        return false;
    }
}
